#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 16 19:11:19 2019

@author: hyungmokson
"""

## -*- coding: utf-8 -*-
#"""
#Created on Tue Jun 18 18:06:06 2019
#
#@author: SonHyungmok
#"""

import matplotlib as mpl
from matplotlib import rcParams
from matplotlib import rc
from scipy.integrate import odeint
from skimage.transform import resize

mpl.use("pgf")

# Diagnostic banner: report which matplotlib installation is being used.
_mpl_banner = (
    "MATPLOTLIB STUFF:",
    "matplotlib version: " + mpl.__version__,
    "matplotlib directory: " + mpl.__file__,
    " ",
)
for _banner_line in _mpl_banner:
    print(_banner_line)
def figsize(scale, fig_width_pt=350., aspect=0.8):
    r"""Return a matplotlib ``[width, height]`` figure size in inches.

    Parameters
    ----------
    scale : float
        Fraction of the LaTeX text width the figure should span.
    fig_width_pt : float, optional
        LaTeX text width in points (obtain it with ``\the\textwidth``).
        Defaults to 350 pt, the value this function always used.
    aspect : float, optional
        Height / width ratio. Defaults to 0.8; the golden ratio would be
        ``(sqrt(5) - 1) / 2``.

    Returns
    -------
    list of float
        ``[fig_width, fig_height]`` in inches.
    """
    inches_per_pt = 1.0 / 72.27                  # TeX points per inch
    fig_width = fig_width_pt * inches_per_pt * scale
    fig_height = fig_width * aspect
    return [fig_width, fig_height]

#
# LaTeX / pgf rcParams applied to every figure produced by this script.
pgf_with_latex = {                      # setup matplotlib to use latex for output
    "pgf.texsystem": "pdflatex",        # change this if using xelatex or lualatex
    "text.usetex": True,                # use LaTeX to write all text
    "axes.linewidth": 1,
    "axes.labelsize": 7,                # LaTeX default is 10pt font.
    "legend.fontsize": 7,               # Make the legend/label fonts a little smaller
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "figure.figsize": figsize(0.9),     # default fig size of 0.9 textwidth
    # pgf.preamble must be a single string: list values were deprecated in
    # matplotlib 3.3 and are rejected by later releases. Join the package
    # lines with newlines instead.
    "pgf.preamble": "\n".join([
        r"\usepackage{amsmath}",
        r"\usepackage{amstext}",
        r"\usepackage[utf8x]{inputenc}",
        r"\usepackage[T1]{fontenc}",
        r"\usepackage{cmbright}",       # plots will be generated using this preamble
    ]),
}

# Lighter variant without preamble / figure size (not applied in the visible
# part of this script).
pgf_with_latex_simple = {               # setup matplotlib to use latex for output
    "pgf.texsystem": "pdflatex",        # change this if using xelatex or lualatex
    "text.usetex": True,                # use LaTeX to write all text
    "axes.labelsize": 7,                # LaTeX default is 10pt font.
    "legend.fontsize": 7,               # Make the legend/label fonts a little smaller
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
}

rcParams.update(pgf_with_latex)
rc('text', usetex=True)
rcParams['font.family'] = 'sans-serif'
rcParams['font.sans-serif'] = ['Arial']

###########################################################################################################
###########################################################################################################
## from here, my code
##
## Earlier revisions of this analysis (author: hyungmokson) were created on
## Mon May 27 17:15:28 2019, Tue May 28 00:41:50 2019 and Wed Jun 12 04:17:02 2019.
## (Only the first two lines of a file can act as shebang / encoding
## declarations, so repeated copies of them here had no effect.)

import numpy as np
from numpy import linalg
import matplotlib.pyplot as plt
import os
from os import listdir
from os.path import isfile, join
from scipy.optimize import curve_fit


if __name__ == "__main__":
    # Input CSVs are read from, and plots written below, the current directory.
    working_dir = os.getcwd()
    sub = os.listdir(working_dir)
    cooling4V_dir = working_dir
    # Additive corrections applied to the loaded temperature data (both zero).
    mol_temp_correction, na_temp_correction = 0.0, 0.0
    plotpath = os.path.join(working_dir, 'plot save')
    
    def outputData(main, fileName, flg):
        """Load one ``(var, data, err)`` series from a CSV file.

        The file read is ``<main>/<folder>/csv/<fileName>.csv``, where
        ``folder`` is ``'without Na'`` when ``flg == 0`` and ``'with Na'``
        otherwise. Lines starting with ``#`` are ignored.

        Returns
        -------
        (x, y, e) : tuple of 1-D float numpy arrays
            The first, second and third CSV columns.
        """
        subdir = 'without Na' if flg == 0 else 'with Na'
        path = join(main, subdir, 'csv', str(fileName) + '.csv')
        table = np.genfromtxt(path, delimiter=',', dtype=float, comments="#",
                              names=('var', 'data', 'err'))
        # genfromtxt squeezes a single-row file into a 0-d record; promote it
        # to 1-D so callers can slice every series the same way.
        table = np.atleast_1d(table)
        x = np.array(table['var'])
        y = np.array(table['data'])
        e = np.array(table['err'])
        return x, y, e
    
    main = cooling4V_dir

    ######################
    ##### without Na #####
    ######################
    t_molnum_no, n_molnum_no, err_molnum_no = outputData(main, 'mol num', 0)
    t_moltemp_no, T_moltemp_no, err_moltemp_no = outputData(main, 'mol temp', 0)
    ######################
    ####### with Na ######
    ######################
    t_molnum, n_molnum, err_molnum = outputData(main, 'mol num', 1)
    t_nanum, n_nanum, err_nanum = outputData(main, 'na num', 1)
    t_moltemp, T_moltemp, err_moltemp = outputData(main, 'mol temp', 1)
    t_natemp, T_natemp, err_natemp = outputData(main, 'na temp', 1)

    offset = 0.
    plotFlg = 0 ##  drawing guide lines

    def _trim_series(t, y, err, start, y_shift=None):
        """Return ``(t, y, err)`` with the first ``start`` points dropped.

        Times get the global ``offset`` added and are divided by 1000 (later
        plots label them in seconds). When ``y_shift`` is given it is added
        to ``y``.
        """
        t = (t[start:] + offset)/1000.
        y = y[start:] if y_shift is None else y[start:] + y_shift
        return t, y, err[start:]

    # Number of leading points discarded from every series.
    idx = 0
    t_molnum_no, n_molnum_no, err_molnum_no = _trim_series(t_molnum_no, n_molnum_no, err_molnum_no, idx)
    molnum_no = [t_molnum_no, n_molnum_no, err_molnum_no]

    # Trim/correct T_moltemp_no itself so it stays aligned with t_moltemp_no
    # (the temperature plot pairs the two). T_motemp_no is kept as an alias
    # for the previously used misspelled name.
    t_moltemp_no, T_moltemp_no, err_moltemp_no = _trim_series(t_moltemp_no, T_moltemp_no, err_moltemp_no, idx, mol_temp_correction)
    T_motemp_no = T_moltemp_no
    moltemp_no = [t_moltemp_no, T_moltemp_no, err_moltemp_no]

    t_molnum, n_molnum, err_molnum = _trim_series(t_molnum, n_molnum, err_molnum, idx)
    molnum = [t_molnum, n_molnum, err_molnum]

    t_nanum, n_nanum, err_nanum = _trim_series(t_nanum, n_nanum, err_nanum, idx)
    nanum = [t_nanum, n_nanum, err_nanum]

    t_moltemp, T_moltemp, err_moltemp = _trim_series(t_moltemp, T_moltemp, err_moltemp, idx, mol_temp_correction)
    moltemp = [t_moltemp, T_moltemp, err_moltemp]

    t_natemp, T_natemp, err_natemp = _trim_series(t_natemp, T_natemp, err_natemp, idx, na_temp_correction)
    natemp = [t_natemp, T_natemp, err_natemp]

#    def weightedAvgStd(data, err):
#        temp = data * 1./err
#        norm_factor = np.sum(1./err)
#        avg = np.sum(temp)/norm_factor
#        std = np.sqrt(np.size(err))/norm_factor
#        return avg, std
    
#    def weightedAvgStd(data, err):
#        temp = data * (err**-2)
#        norm_factor = np.sum(err**-2)
#        avg = np.sum(temp)/norm_factor
#        std = 1./np.sqrt(norm_factor)
#        return avg, std
    
    def weightedAvgStd(data, err):
        """Return ``(mean, uncertainty)`` of a data series.

        Despite the name, the mean is UNWEIGHTED. The uncertainty combines,
        in quadrature:

        * the standard error from the scatter of the points,
          ``std(data) / sqrt(N - 1)`` (population std); and
        * the per-point errors added in quadrature and divided by N.

        ``data`` and ``err`` are numpy arrays. N must be at least 2; with one
        point the scatter term is 0/0.
        """
        n_points = len(data)
        mean = np.average(data)
        scatter_term = np.std(data) / np.sqrt(n_points - 1)
        error_term = np.sqrt(np.sum(err ** 2)) / len(err)
        return mean, np.sqrt(scatter_term ** 2 + error_term ** 2)

    
    ################################
    ######## plot Na num ###########
    ################################
    num_factor = 1e5
    # Boundary between the "before" and "after" averaging windows:
    # 0.3 s plus the global offset (offset is in raw time units, hence /1000).
    l = 0.3 + offset/1000.
    idx = np.where(np.abs(t_nanum-l) < 1e-3)
    if idx[0].size == 0:
        raise ValueError("no Na number point within 1e-3 of t = " + str(l))
    i = idx[0][0]
    idx_cut = i
    time_cut = t_nanum[idx_cut]

    # Na number averaged over all points up to and including time_cut.
    err = err_nanum[:idx_cut+1]
    w_avg, w_std = weightedAvgStd(n_nanum[:idx_cut+1], err)
    nanum_avg = w_avg
    nanum_std = w_std

    print()
    print("Na number avg upto " + str(time_cut) + " sec :  " + str(nanum_avg) + "  " + str(w_avg))
    print("Na number std upto " + str(time_cut) + " sec :  " + str(nanum_std) + "  " + str(w_std))
    print()

    # NOTE(review): T_natemp is sliced with idx_cut found on t_nanum, which
    # assumes both Na series share the same time grid -- confirm.
    natemp_avg, natemp_std = weightedAvgStd(T_natemp[:idx_cut + 1],  err_natemp[:idx_cut+1])

    print()
    print("Na temp avg upto " + str(time_cut) + " sec :  " + str(natemp_avg))
    print("Na temp std upto " + str(time_cut) + " sec :  " + str(natemp_std))
    print()

    natemp_avg2, natemp_std2 = weightedAvgStd(T_natemp[idx_cut:], err_natemp[idx_cut:])

    print("Na temp from " + str(time_cut) + " sec :  " + str(natemp_avg2))
    print("Na temp std from " + str(time_cut) + " sec :  " + str(natemp_std2))
    print()
    print("Na temp avg for total = " + str(np.average(T_natemp)))
    print("Na temp std for total = " + str(np.sqrt(np.sum(err_natemp**2))/np.size(err_natemp)))
    print("")
    print("==== How much Na temp increased after thermalization (before & after 0.3sec) ====")
    print("temp increase by " + str(natemp_avg2 - natemp_avg))
    print("temp increase by " + str((natemp_avg2 - natemp_avg)/natemp_avg * 100) + " %")
    ###############
    # From here on: natemp_init/_std = before the cut, natemp_avg/_std = after.
    natemp_init, natemp_init_std = natemp_avg, natemp_std
    natemp_avg, natemp_std = natemp_avg2, natemp_std2
    diff_std = np.sqrt(natemp_init_std**2 + natemp_std**2)

    print("")
    print("")
    print("Na temp weighted average upto " + str(time_cut) + "sec =  " +str(natemp_init) + " / " + str(natemp_init_std))
    print("Na temp weighted average from " + str(time_cut) + "sec =  " +str(natemp_avg) + " / " + str(natemp_std))
    print("==== How much Na temp increased after thermalization (before & after 0.3sec) ====")
    print("temp increase by " + str(-natemp_init + natemp_avg) + " / " + str(diff_std))
    # Relative increase with respect to the initial (pre-cut) temperature,
    # consistent with the percentage printed above.
    print("temp increase by " + str((natemp_avg - natemp_init)/natemp_init * 100) + " %")
    print()
    print()
    print()

    plt.close()
    
    ###########################################################################
    ####### thermalization ##############
    ###########################################################################
    ## Molecule settled (final) temperature: average the NaLi-with-Na
    ## temperature points from t = 0.3 s (+ offset) onwards.
    l =  0.3 + offset/1000.
    # Near-exact match of the cut time; ii[0][0] raises IndexError if no point
    # lies within 1e-10 of l.
    ii = np.where(np.abs(t_moltemp - l) < 1e-10)
    i = ii[0][0]
    i_cut = i 
    
    ####### 0.6sec data is definitely outlier######
    # Drops the third point at/after the cut (index 2 of the tail).
    # NOTE(review): this is the 0.6 s point only for the time grid of the
    # current data set -- confirm before changing the cut or the data.
    temp =  T_moltemp[i_cut:]
    temp_err =  err_moltemp[i_cut:]

    a = temp[:2]
    a_err = temp_err[:2]
    b = temp[3:]
    b_err = temp_err[3:]
    temp = np.concatenate((a,b))
    temp_err = np.concatenate((a_err,b_err))
    moltemp_avg, moltemp_std = weightedAvgStd(temp, temp_err)

    # Dense time axis for plotting fitted curves (redefined identically below).
    t = np.linspace(0, 1.2, 5000)
    moltemp_min = np.amin(T_moltemp)
    moltemp_max = np.amax(T_moltemp)
       
    ###########################################
    #### plot data as scatter + error bar #####
    ###########################################
    plt.errorbar(t_moltemp_no,T_moltemp_no, yerr = err_moltemp_no, fmt = 'o', mew=3)
    plt.errorbar(t_moltemp,T_moltemp, yerr = err_moltemp, fmt = 'o', mew=3)
    plt.errorbar(t_natemp,T_natemp, yerr = err_natemp, fmt = 'o', mew=3)
  

    plt.legend(['NaLi without Na', 'NaLi', 'Na'])
    plt.ylim([1, 2.55])
    plt.xlim([0, 1.1])

    ##########################################################################
    ###### fit the mol temp to simple exponential
    ###### upto 0.2 sec, there's no Na decay
    t = np.linspace(0., 1.2, 5000)
    def onebodyloss_offset(t, temperature, l, offset):
        """Exponential decay ``temperature * exp(-l * t)`` on top of a constant ``offset``."""
        decay = np.exp(-l * t)
        return offset + temperature * decay

    # The settled (final) temperature constrains the fit offset; only the
    # molecule value is used.
    temp_avg = moltemp_avg
    temp_std = moltemp_std
    print()
    print("The average of settled temp = " + str(temp_avg))
    print("The std of settled temp = " + str(temp_std))
    ul_factor = 1.
    dl_factor = 1.
    print()
    # Bounds for (amplitude, rate in 1/s, offset): the offset may move
    # dl_factor/ul_factor standard deviations around the settled temperature.
    b = ([0.3, 0, temp_avg - dl_factor*temp_std], [.8, 100, temp_avg + ul_factor * temp_std])
    print("upper limit = " + str(temp_avg + ul_factor * temp_std))
    print("lower limit = " + str(temp_avg - dl_factor * temp_std))
    print("")

    idx = 0
    popt_one, pcov_one = curve_fit(onebodyloss_offset, t_moltemp[idx:], T_moltemp[idx:], p0 =[0.5, 1 ,temp_avg],  bounds=b, sigma = err_moltemp[idx:])
    popt_one_off = popt_one
    pcov_one_off = pcov_one
    print("")
    print("")
    print("**** Molecule thermalization SIMPLE EXPONENTIAL  fit result *****")
    print()
    print("Exponential fit result = " + str(popt_one))
    print("")
    err = np.sqrt(np.diag(pcov_one))
    init_temp = popt_one[0] + popt_one[2]
    thermo_rate = popt_one[1]
    # init_temp is amplitude + offset. These fit parameters are correlated,
    # so propagate with the covariance term:
    # var(p0 + p2) = C00 + C22 + 2*C02 (clamped at 0 against round-off).
    init_temp_err = np.sqrt(max(0., pcov_one[0, 0] + pcov_one[2, 2] + 2. * pcov_one[0, 2]))
    thermo_rate_err = err[1]
    print("final temp / std = " + str(popt_one[2]) + " / " +str(err[2]))
    print("std/avg * 100 = " + str(err[2]/popt_one[2] *100))
    print("")
    print("initial temp / std = " + str(init_temp) + " / " + str(init_temp_err))
    print("std/avg *100 = " + str(init_temp_err/init_temp *100))
    print("")
    # init_temp - offset is just the fitted amplitude, whose std is err[0].
    print("diff of init & final temp = " + str(init_temp - popt_one[2]) + " / " + str(err[0]))
    print("")
    print("thermalization rate / std = " + str(thermo_rate) + " / " + str(thermo_rate_err))
    print("std/avg * 100 = " + str(thermo_rate_err/thermo_rate * 100))
    print("")
    print("")
    plt.plot(t, onebodyloss_offset(t, popt_one[0], popt_one[1], popt_one[2]), 'b--')

    plt.xlabel("Hold (sec)")
    plt.ylabel(r"$Temperature\ (\mu K)$")
    plt.close()
     
    #########################################################
    #### what is the collision cross section ######
    #########################################################
    # Lattice spacing 1596e-3/2 (same length unit as sigma, um).
    a = (1596*1e-3)/2.
    # Gaussian width of the cloud (um). An earlier estimate of
    # 165.95 * 6.5/1.5 um was superseded by this value.
    sigma = 720 ##um
    print("")
    print("Gaussian sigma = " + str(sigma) + " um")
    l_eff = np.sqrt(8./np.pi) * sigma
    # Number of lattice sites spanned by the cloud.
    particle_per_site = l_eff/a

    ## Na trap frequencies for 4V trap  ##
    na_trapfreq_1 = 600 * np.sqrt(4./5.)
    na_trapfreq_2 = 462.5 * np.sqrt(4./5.)
    na_trapfreq_3 = 33.8*1e3
    na_trapfreq_4V = [na_trapfreq_1, na_trapfreq_2, na_trapfreq_3]

    factor = (1./np.pi**3) *(np.pi * 1.38 * 6./(23*10**-3))**(3./2.)
    print(factor)

    # Na cloud volume V = kappa * T^(3/2) (T converted from uK to K).
    kappa = factor * 1e6/(np.prod(na_trapfreq_4V)) ## cm^3 ##
    na_volume_4V = kappa * (natemp_init*10**-6)**(3./2.) ## cm^3 ##
    # dV/V = (3/2) dT/T. kappa is already part of na_volume_4V and must not be
    # multiplied in a second time.
    err_na_volume_4V = 3./2. * na_volume_4V * natemp_init_std/natemp_init

    # Na density per lattice site.
    na_density_4V = nanum_avg/na_volume_4V/particle_per_site
    err_na_density_4V = na_density_4V * np.sqrt((err_na_volume_4V/na_volume_4V)**2 + (nanum_std/nanum_avg)**2)

    print("")
    print("Na density at 4V = " + str(na_density_4V/1e12) + " x 10^12 (#/cm^3)")
    print("Na density std = " + str(err_na_density_4V/1e12) + " x 10^12 (#/cm^3)")
    print("")

    # Thermal speed sqrt(3 k_B T / m) of Na (1.38 * 6.02 = k_B * N_A), in cm/s.
    navelocity = np.sqrt(3 * 1.38 * (natemp_init*1e-6)/(23 * 1e-3/6.02)) * 100. ## cm/s
    err_navelocity = .5 * navelocity/natemp_init * natemp_init_std

    print("Na velocity at temperature of " + str(natemp_init) + " is " + str(navelocity) + " cm/s")
    print("Na velocity std at temperature of " + str(natemp_init) + " is " + str(err_navelocity) + " cm/s")

    collision_crosssection = thermo_rate/(na_density_4V * navelocity)
    err_collision_crosssection = collision_crosssection * np.sqrt((thermo_rate_err/thermo_rate)**2 + (err_navelocity/navelocity)**2 + (err_na_density_4V/na_density_4V)**2)

    print("")
    print("NaLi-Na collisional cross section (from 4V) = " + str(collision_crosssection) + " cm^2")
    print("NaLi-Na collisional cross section std (from 4V) = " + str(err_collision_crosssection) + " cm^2")
    print("")

    # NOTE(review): this uses cross section = pi * a^2, while the later
    # "correction" section uses 4 * pi * a^2. Confirm which convention is intended.
    scatt_length = np.sqrt(collision_crosssection/np.pi) ## cm
    err_scatt_length = scatt_length * err_collision_crosssection/(2. * collision_crosssection) ##cm
    print("NaLi-Na scattering length = " + str(scatt_length * 1e7) + " nm")
    print("NaLi-Na scattering length std = " + str(err_scatt_length * 1e7) + " nm")
    print("")


    def twobodyloss(t, N, beta):
        """Two-body loss curve ``1 / (beta * t + 1 / N)``; ``N`` is the t = 0 value."""
        inverse_n = 1. / N + beta * t
        return 1. / inverse_n
    
    # Two-body loss fit of the molecule number WITHOUT Na:
    # N(t) = 1/(beta*t + 1/N0), bounds N0 in [18000, 25000], beta in [0, 1].
    b = ([18000, 0], [25000, 1])
    popt_two, pcov_two = curve_fit(twobodyloss, t_molnum_no, n_molnum_no, p0 =[20000., 0.1],  bounds=b, sigma = err_molnum_no)
    
    n_mol = popt_two[0]
    beta = popt_two[1]
#    gamma_mol = popt_one[1]

    n_mol_err = np.sqrt(np.diag(pcov_two))[0]
    beta_err = np.sqrt(np.diag(pcov_two))[1]

    # Molecule number vs hold time, without (blue) and with (red) Na.
    plt.errorbar(t_molnum_no, n_molnum_no, yerr = err_molnum_no, fmt = 'o', mew= 3, c = 'blue')
    plt.errorbar(t_molnum, n_molnum, yerr = err_molnum, fmt = 'o', mew=3, c = 'red')
    plt.legend(['without Na', 'with Na'])

    ######################################
    ##### two body loss fit (data without Na) ######
    ######################################
    plt.plot(t, twobodyloss(t, popt_two[0], popt_two[1]), 'b--')
    print()
    print("mol num / std = " + str(n_mol) + " / " + str(n_mol_err))
    print("beta /std = " + str(beta) + " / " + str(beta_err))
    print()
    # Initial two-body loss rate beta*N0, i.e. (1/N) dN/dt at t = 0 for
    # dN/dt = -beta*N^2.
    twobodyrate = n_mol * beta
    twobodyrate_err = twobodyrate * np.sqrt((beta_err/beta)**2 + (n_mol_err/n_mol)**2)
    print("two body rate / std = " + str(twobodyrate) + " / " + str(twobodyrate_err))
    print()
    print()

    ############################################################
    ############################################################
    ####
    ############################################################
    #### using the Rudy Grim paper ####
    print("")
    print(" **** Referencing Rudy Grim's Dy-K paper ****")
    # Masses in g/mol: Na = 23, NaLi = 23 + 6.
    m_na = 23.
    m_mol = (23.+6.)
#    (1/1596. - 1/671.)/(1/1596. - 1/589.)
    # Molecule / Na polarizability ratio (fixed value).
    pol_ratio_mol_to_na = 2.6
    print("polarizability ratio a_mol/a_na = " + str(pol_ratio_mol_to_na))
    # Dimensionless factor A built from the mass ratio, the polarizability
    # ratio and the molecule/Na temperature ratio.
    A = (1 + m_na/m_mol * init_temp/natemp_init)**(1/2.) * (1 + 1./pol_ratio_mol_to_na * init_temp/natemp_init)**(-3/2.)
    # Mean relative speed sqrt(8 k_B/pi * (T_mol/m_mol + T_na/m_na)) in m/s
    # (temperatures in uK, masses converted from g/mol to kg per particle).
    v_rel = np.sqrt(8 * 1.38 * 10**-23/np.pi * (init_temp *10**-6/(m_mol*10**-3/(6.02*10**23)) + natemp_init*10**-6/(m_na*10**-3/(6.02*10**23))))    
    # Uncertainty of v_rel from init_temp_err and natemp_init_std, in quadrature.
    err_v_rel = v_rel/(init_temp *10**-6/(m_mol*10**-3/(6.02*10**23)) + natemp_init*10**-6/(m_na*10**-3/(6.02*10**23))) * np.sqrt((init_temp_err* 10**-6/(2 * (m_mol*10**-3/(6.02*10**23))))**2 + (natemp_init_std* 10**-6/(2 * (m_na*10**-3/(6.02*10**23))))**2 )
    print("factor A = " + str(A))
    print("relative velocity = " + str(v_rel * 100) + " cm/s")
    print("err relative velocity = " + str(err_v_rel * 100) + " cm/s")

    
#    mol_trapfreq_4V = []
#    mol_trapfreq_1 = 800 * np.sqrt(4/4.5)
#    mol_trapfreq_2 = 650 * np.sqrt(4/4.5)
#    trapfreq_ratio_mol_to_na = np.average([800/600., 650.462.5])
#    mol_trapfreq_3 = 33.8 * 1e3  * trapfreq_ratio_mol_to_na
#    
#    moltrapfreq_4V.append(mol_trapfreq_1)
#    moltrapfreq_4V.append(mol_trapfreq_2)
#    moltrapfreq_4V.append(mol_trapfreq_3)
#    mol_ omega_mean = ((2*np.pi)**3 * np.prod(mol_trapfreq_4V)**1.3
    
                        
#    particle_per_site = 566
        
    # Lattice spacing, same value as defined earlier (um).
    a = (1596*1e-3)/2.
#    sigma = 200. *6.45/1.5 ##um
#    sigma = 239.5 * 6.5/1.5
#    sigma = 165.95 *6.5/1.5 ##um
    print("")
    print("Gaussian sigma = " + str(sigma) + " um")
    # Effective cloud length sqrt(8/pi)*sigma divided by the lattice spacing.
    # Despite its name, particle_per_site is the number of lattice sites
    # (see the print below).
    l_eff = np.sqrt(8./np.pi) * sigma
    particle_per_site = l_eff/a
    print("")
    print("the number of lattice sites = "  + str(particle_per_site))
    print("")
    # Geometric-mean angular trap frequency of the Na trap.
    na_omega_mean = 2*np.pi * (np.prod(na_trapfreq_4V))**(1./3.)
#    print("look at here = " + str(2 * np.pi * 6.02 * 1.38 * natemp_avg * 10**-6/(m_na * 10**-3) * (1 + 1/pol_ratio_mol_to_na * init_temp/natemp_avg)))
    # Overlap factor per (molecule number * Na number). The last factor depends
    # on the polarizability ratio and on init_temp/natemp_init.
    overlap = na_omega_mean**3 * (2 * np.pi * 6.02 * 1.38 * natemp_init * 10**-6/(m_na * 10**-3))**(-3./2.) * (1 + (1/pol_ratio_mol_to_na) * (init_temp/natemp_init))**(-3./2.)
    # Uncertainty of `overlap` propagated from natemp_init_std and init_temp_err.
    err_overlap = overlap * 1.5 * np.sqrt((natemp_init_std* 10**-6 * ((2 * np.pi * 6.02 * 1.38/(m_na * 10**-3)*(2 * np.pi * 6.02 * 1.38 * natemp_init * 10**-6/(m_na * 10**-3))**(-1) - (init_temp/natemp_init)* (natemp_init* 10**-6)**(-1) *(1/pol_ratio_mol_to_na) *(1 + (1/pol_ratio_mol_to_na) * (init_temp/natemp_init))**(-1))))**2 + (init_temp_err * 10**-6*(1 + (1/pol_ratio_mol_to_na) * (init_temp/natemp_init))**(-1))**2)
    print("overlap integral/(molecule num * na num) = " + str(overlap/1e6/1e12))
    print("err overlap integral/(molecule num * na num) = " + str(err_overlap/1e6/1e12))

    # Effective Na density seen by the molecules, per lattice site
    # (printed below as x 10^12 1/cm^3 after dividing by 1e6 and 1e12).
    eff_na_density = overlap * nanum_avg/particle_per_site
    err_eff_na_density = eff_na_density * np.sqrt((nanum_std/nanum_avg)**2 + (err_overlap/overlap)**2)
    print("Na_num/particle_per_site * overlap = " + str(eff_na_density/1e6/1e12) + " x 10^12 (1/cm^3)")
    print("err Na_num/particle_per_site * overlap = " + str(err_eff_na_density/1e6/1e12) + " x 10^12 (1/cm^3)")
    print("ratio = " + str(err_eff_na_density/eff_na_density))

    # Full overlap integral (times both atom numbers), then per lattice site.
    overlap_int = n_mol * nanum_avg * overlap
    overlap_int_per_site = overlap_int/(particle_per_site)**2
    err_overlap_int_per_site = overlap_int_per_site * np.sqrt((n_mol_err/n_mol)**2 + (nanum_std/nanum_avg)**2 + (err_overlap/overlap)**2)
    print("")
    print("overlap integral = " + str(overlap_int/1e6/1e12) + " x 10^12 (1/cm^3)")
    print("overlap integral per lattice site = " + str(overlap_int_per_site/1e6/1e12) + " x 10^12 (1/cm^3)")
    print("err overlap integral per lattice site = " + str(err_overlap_int_per_site/1e6/1e12) + " x 10^12 (1/cm^3)")
    print("")
    print("overlap integral per lattice site / (molecule num/particle_per_site)= " + str(overlap_int_per_site/1e6/1e12/(n_mol/particle_per_site)) + " x 10^12 (1/cm^3)")
    print("")
    
    # Cross section (m^2) from the thermalization rate; scales as 1/n_mol.
    cross_sec = thermo_rate * (A * nanum_avg * n_mol * (m_na * 10**-3)/(6.02*10**23) * na_omega_mean**3/(np.pi**2 * 1.38*10**-23 * natemp_init * 10**-6))**-1
    cross_sec_per_site = cross_sec * (particle_per_site)**2
    err_cross_sec_per_site = np.sqrt(((cross_sec_per_site/thermo_rate) * thermo_rate_err)**2 + (cross_sec_per_site/natemp_init * natemp_init_std)**2 + (cross_sec_per_site/n_mol * n_mol_err)**2 +  ((cross_sec_per_site/nanum_avg) * nanum_std)**2)

    bohr_radius = 0.529 * 10**-10 ##m
    print("Na num = " + str(nanum_avg))
    print("mol num = " + str(n_mol))
    print("Na temp = " + str(natemp_init))
    print("mol temp = " + str(init_temp))
    print("thermo rate = " + str(thermo_rate))
    print("the collision cross section = " + str(cross_sec_per_site * 1e4) + " cm^2")
    print("the collision cross section from overlap integral = " + str(thermo_rate/(v_rel * overlap_int_per_site) *1e4) + " cm^2")
    print("the collision cross section std = " + str(err_cross_sec_per_site *1e4) + " cm^2")
    print("")
    print(" ---------- correction ----------- ")
    print("")
    # Mass factor xi = 4 m1 m2 / (m1 + m2)^2; 3/xi collisions per thermalization.
    xi = 4. * m_na * m_mol * (m_na + m_mol)**-2
    coll_factor = 3./xi
    old = cross_sec_per_site
    err_old= err_cross_sec_per_site

    cross_sec_per_site = old * coll_factor * n_mol/particle_per_site
    # old ∝ 1/n_mol, so n_mol cancels in the corrected value. Its uncertainty
    # must be dropped, not added a second time.
    err_cross_sec_per_site = cross_sec_per_site * np.sqrt((thermo_rate_err/thermo_rate)**2 + (natemp_init_std/natemp_init)**2 + (nanum_std/nanum_avg)**2)
    err_cross_sec_per_site_2 = cross_sec_per_site * np.sqrt((thermo_rate_err/thermo_rate)**2 + (n_mol_err/n_mol)**2 + (err_v_rel/v_rel)**2 + (err_overlap_int_per_site/overlap_int_per_site)**2)

    print("the collision cross section = " + str(cross_sec_per_site * 1e4) + " cm^2")
    print("the collision cross section std = " + str(err_cross_sec_per_site * 1e4) + " cm^2")
    print("the collision cross section '2' std = " + str(err_cross_sec_per_site_2 * 1e4) + " cm^2")
    print("std/avg  = " + str(err_cross_sec_per_site/cross_sec_per_site))
    print(" ---------- correction ----------- ")
    print("")

    # sigma = 4 pi a^2  =>  a = sqrt(sigma / (4 pi)),  da = a * dsigma / (2 sigma).
    scattering_length = np.sqrt(cross_sec_per_site/(4*np.pi))
    err_scattering_length = scattering_length * err_cross_sec_per_site_2/(2. * cross_sec_per_site)
    print("scattering length = " + str(scattering_length *1e9) + " nm")
    print("scattering length std = " + str(err_scattering_length *1e9) + " nm")
    print("scattering length = " + str(scattering_length/bohr_radius) + " a_o")
    print("scattering length std = " + str(err_scattering_length/bohr_radius) + " a_o")
    
    ################################################
    #### Uncertainty of the scattering length ######
    ################################################
    


    #########################################
    ##### one body loss fit with Na ######
    #########################################
    def onebodyloss(t, num, l):
        """Pure exponential decay ``num * exp(-l * t)`` (one-body loss)."""
        exponent = -l * t
        return num * np.exp(exponent)

    # One-body (exponential) fit of the molecule number WITH Na.
    # Bounds: N0 in [18000, 25000], rate in [0, 10].
    b = ([18000, 0], [25000, 10])
    popt_one, pcov_one = curve_fit(onebodyloss, t_molnum, n_molnum, p0 =[20000., 0.1],  bounds=b, sigma = err_molnum)
    plt.plot(t, onebodyloss(t, popt_one[0], popt_one[1]), 'r--')

    # NOTE: n_mol / n_mol_err are overwritten here with the with-Na fit values
    # (they previously held the two-body fit result without Na).
    n_mol = popt_one[0]
    gamma = popt_one[1]
    
    err = np.sqrt(np.diag(pcov_one))
    n_mol_err = err[0]
    gamma_err = err[1]
    print()
    print("mol num / std = " + str(n_mol) + " / " + str(n_mol_err))
    print("gamma /std = " + str(gamma) + " / " + str(gamma_err))
    print()
    print(nanum_avg)
    
    # Elastic collision rate = (3/xi) * thermalization rate.
    elastic_rate = coll_factor * thermo_rate
    err_elastic_rate = coll_factor * thermo_rate_err
    r = elastic_rate/gamma
    err_r = r * np.sqrt((gamma_err/gamma)**2 + (thermo_rate_err/thermo_rate)**2)
    print("")
    print("xi = " +str(xi))
    print("collision factor = " + str(coll_factor))
    print("1/collision factor = " + str(1./coll_factor))
    print("elastic collision rate = " + str(elastic_rate) + " / std = " + str(err_elastic_rate))
    print("")
    print("the ratio of thermalization rate/one-body loss rate = " + str(r))
    print("std = " + str(err_r))
    print("")
    
    
    # Extra loss caused by Na: with-Na decay rate minus the initial two-body
    # rate measured without Na.
    ratediff = gamma - twobodyrate    
    ratediff_err = np.sqrt(twobodyrate_err**2 + gamma_err**2)
    print("(the one body rate - two body rate) / std = " + str(ratediff) + " / " + str(ratediff_err))
    print("std/avg * 100 = " + str(ratediff_err/ratediff * 100.))
    print("")
    print("rate diff + 1 sigma = " + str(ratediff + ratediff_err))
    print("rate diff - 1 sigma = " + str(ratediff - ratediff_err))
    print("")
    print("")
    ###########################################################################
    # Loss rate coefficient: rate difference divided by the Na density.
    lossRateCoeff = ratediff/na_density_4V
    lossRateCoeff_err = lossRateCoeff * np.sqrt((ratediff_err/ratediff)**2 + (err_na_density_4V/na_density_4V)**2)
    print("")
    print("rate diff normalized by the Na density = "  + str(lossRateCoeff) + " (cm^3 / s)")
    print("rate diff normalized by the Na density ERR = " + str(lossRateCoeff_err) + " (cm^3 / s)")
    print("std/avg = " + str(lossRateCoeff_err/lossRateCoeff))

    # Same, using the overlap-based effective density (/1e6 converts m^-3 to cm^-3).
    eff_lossRateCoeff = ratediff/(eff_na_density/1e6)
    eff_lossRateCoeff_err = eff_lossRateCoeff * np.sqrt((ratediff_err/ratediff)**2 + (err_eff_na_density/eff_na_density)**2)
    print("")
    print("rate diff normalized by 'effective' overlap Na density = "  + str(eff_lossRateCoeff) + " (cm^3 / s)")
    print("rate diff normalized by 'effective' overlap Na density ERR = " + str(eff_lossRateCoeff_err) + " (cm^3 / s)")
    print("std/avg = " + str(eff_lossRateCoeff_err/eff_lossRateCoeff))
    print("")
    ###########################################################################
    ## Universal rate coeff calculation ##
    mu = 1./(1./m_mol + 1./m_na) * 10**-3/(6.02*10**23)
    h_bar = 1.0545718 * 10**(-34)
    num_factor = 0.477999
    
    a_bohr = 5.29 * 10**-11
    hartree = (2.6355 * 10**6)/(6.02*10**23) ## joule
#    epsilon = 8.854188 * 10**-12
    C_6_NaLi = 1467. * (a_bohr**6 * hartree)
    C_6_NaNa = 1556. * (a_bohr**6 * hartree)
    C_6 = C_6_NaLi + C_6_NaNa 
    
    a_bar = num_factor * (2 * mu * C_6/h_bar**2)**(1./4.)

    reduced_mass = 29*23./(29+23.) * 10**(-3) *(6.02*10**23)**(-1)
    g = np.sqrt(2 * reduced_mass * C_6 * h_bar**-2)
    a = 13.47703 * 10**-9
    r_eff = 1.3947 * np.sqrt(g) - 1.3333 * g**2/a + 0.63732 * g**1.5/a**2
    k_vec = reduced_mass * v_rel/h_bar
    print("C_6 between NaLi & Na = " + str(C_6))
    print("r_eff = " + str(r_eff))
    print("relative k vector = " + str(k_vec))
    print("(k x a)^2 =" + str((k_vec * a)**2))
    print("k^2 x a x r_eff = " + str(k_vec**2 * a * r_eff))
    print("Energy factor in denominator of ealstic cross-section = " + str((1-k_vec**2 * a * r_eff)**2 + (k_vec * a)**2))
    print("")
    #    alpha = (r_eff/2.) * k_vec**2
#    s = collision_crosssection * 10**(-4)
#    scatt_length_exact = (2 * alpha * s - np.sqrt(4*alpha**2*s**2 - 4*s*(s*alpha**2 + s * k_vec**2 - 4*np.pi)))/(2 * (s * alpha**2 + s *k_vec**2 - 4*np.pi))
#    print("NaLi-Na scattering length EXACT = " + str(scatt_length_exact * 1e9) + " nm")
    print("")
    ## s-wave universal loss coefficient
    K_swave = (4 * np.pi * h_bar/mu) * a_bar * 10**6
    print("s-wave universal loss coefficient = " + str(K_swave) +  " (cm^3 / s)")
    print("(s_wave rate coff/loss rate coeff) = " + str(K_swave/lossRateCoeff))
    print("(s_wave rate coff/eff loss rate coeff) = " + str(K_swave/eff_lossRateCoeff))
    ###########################################################################
    ratio = elastic_rate/ratediff
    print("(thermalization rate/differential loss rate) = " + str(ratio))
    ratio_err = np.abs(ratio) * np.sqrt((err_elastic_rate/elastic_rate)**2 + (ratediff_err/ratediff)**2)
    print("ratio std = " + str(ratio_err))
    print("std/avg * 100 = " + str(ratio_err/ratio * 100))
    print("")
    print("min ratio ~ ratio - 1 sigma = " + str(ratio - ratio_err))
    print("max ratio ~ ratio + 1 sigma = " + str(ratio + ratio_err))
    
    plt.xlabel("Hold (sec)")
    plt.ylabel("molecule Number")
    # Axis limits must be set before savefig to affect the written PDF.
    plt.xlim([0, 1.1])
    # savefig does not create missing folders.
    if not os.path.isdir(plotpath):
        os.makedirs(plotpath)
    plt.savefig(join(plotpath, "GS mol lifetime.pdf"))
    plt.close()
  

    #############################################################

    fig = plt.figure()
    conversion = 2.54
    ## single column = 8.9 cm
    fig.set_size_inches(8.9/conversion, 8.9/conversion)

    grid = plt.GridSpec(2,7, wspace = 0.3, hspace = 0.31)
    ax1 = fig.add_subplot(grid[0, 0:3], facecolor='w')
    ax2 = fig.add_subplot(grid[0, 4:], facecolor='w')
    ax3 = fig.add_subplot(grid[1, 1:4], facecolor='w')
    ax4 = fig.add_subplot(grid[1, 4:], facecolor='w')
    
    ax = [ax1, ax2, ax3, ax4]
    ax1.tick_params(axis='both',  which='both')
    ax2.tick_params(axis='both',  which='both')
    ax3.tick_params(axis='both',  which='major')

#    lw = 1.2
    lw = 1.
#    ms = 5.65
    ms = 3.8
#    fmt_list = ['o', 'D', 's']
#    color_list=['#0000ff','#ff0000', '#000000']
    
    fmt_list = ['o', 's', 'x']
    facecolor_list = ['white','#ff0000', '#000000']
    color_list=['#0000ff','#ff0000', '#000000']
    
    
    toPlot = []
    toPlot.append(moltemp_no)
    toPlot.append(moltemp)
    toPlot.append(natemp)

    # Temperature panels: ax[2] and ax[3] (the two halves of the broken
    # x-axis) both show every series in toPlot; ax[0]/ax[1] are skipped here.
    for ax_i in range(2, len(ax)):
        for i, data in enumerate(toPlot):
            x = data[0]
            y = data[1]
            err = data[2]

            (_, caps, _) = ax[ax_i].errorbar(x, y, yerr=err, fmt=fmt_list[i], mew=lw, capsize=0, markersize=ms, c=color_list[i], markerfacecolor=facecolor_list[i])
            # Zero the cap edge width for every series, not only the last one.
            for cap in caps:
                cap.set_markeredgewidth(0.)

    ax[3].legend(['NaLi without Na', 'NaLi with Na', 'Na'], loc='upper right', handlelength=1.0, frameon=True, labelspacing=0.2, facecolor='w', borderpad=0.1, edgecolor='w', fontsize=7, framealpha=0., markerscale=1)

    # --- Simple exponential fit -------------------------------------------
    # Overlay onebodyloss_offset(t, *fitted params) as a dashed red curve on
    # both halves of the broken temperature axis.
    for temp_ax in (ax[2], ax[3]):
        temp_ax.plot(t, onebodyloss_offset(t, popt_one_off[0], popt_one_off[1], popt_one_off[2]), 'r--')

    # --- Numerical simulation fit (disabled) ------------------------------
    # Alternative overlay: integrate d_delT_dt with odeint and plot dTs + dTf.
    # Note that enabling it as written would also rebind `t`.
    # dT0 = 0.670798
    # dTf = 1.204759
    # a_min = 261. * a_bohr
    # t = np.linspace(0, 1.2, 100)
    # dTs = odeint(d_delT_dt, dT0, t, args=(a_min, dTf,))
    # ax[2].plot(t, dTs+dTf, 'r--')
    # ax[3].plot(t, dTs+dTf, 'r--')

    # Shared y-label (on the left half only) and y-range for the temperature
    # panels; the x-ranges split hold time into an early window on ax[2] and
    # a late window on ax[3].  (An earlier, immediately-overridden pair of
    # set_ylim([.9, 2.75]) calls was removed.)
    ylabel = r"Temperature $(\mu\mathrm{K})$"
    ax[2].set_ylabel(ylabel)
    ax[2].set_ylim([.9, 2.5])
    ax[3].set_ylim([.9, 2.5])

    ax[2].set_xlim([-0.0, 0.25])
    ax[3].set_xlim([0.65, 1.05])
   
  
    # Number panel (ax2): series are divided by num_factor so the y-axis
    # reads in units of 1e4 (see the y-label set on ax[1]).
    toPlot = [molnum_no, molnum, nanum]
    num_factor = 1e4

    for i, data in enumerate(toPlot):
        x = data[0]
        y = data[1]/num_factor
        err = data[2]/num_factor

        (_, caps, _) = ax2.errorbar(x, y, yerr=err, fmt=fmt_list[i], mew=lw, capsize=0, markersize=ms, c=color_list[i], markerfacecolor=facecolor_list[i])
        # Zero the cap edge width for every series, not only the last one.
        for cap in caps:
            cap.set_markeredgewidth(0.)
        
    # Optional overlay of the avg +/- std temperature bands on ax[2]; set
    # plotFlg = 1 to enable.  This must be an assignment: the former
    # `plotFlg == 0` was a comparison whose result was discarded.
    plotFlg = 0
    if plotFlg == 1:
        # Separate time grid so the shading does not overwrite `t`, which the
        # ax[1] fit curves are evaluated on afterwards.
        t_band = np.linspace(0, 1.2, 5000)
        moltemp_ulimit = (moltemp_avg + moltemp_std)*np.ones_like(t_band)
        moltemp_dlimit = (moltemp_avg - moltemp_std)*np.ones_like(t_band)
        ax[2].fill_between(t_band, moltemp_ulimit, moltemp_dlimit, facecolor='green', alpha=0.5)
        natemp_ulimit = (natemp_avg + natemp_std)*np.ones_like(t_band)
        natemp_dlimit = (natemp_avg - natemp_std)*np.ones_like(t_band)
        ax[2].fill_between(t_band, natemp_ulimit, natemp_dlimit, facecolor='yellow', alpha=0.5)

    # Legend is created before the fit curves are drawn, so its labels are
    # assigned to the errorbar series already on ax[1], in plotting order.
    # NOTE(review): three series (molnum_no, molnum, nanum) are plotted on this
    # axes but only two labels are given, so the third is unlabeled — confirm.
    ax[1].legend(['without Na', 'with Na'], handlelength = 1.0, loc='upper right', frameon = True, labelspacing = 0.2, facecolor = 'w', borderpad = 0.1, edgecolor ='w',  fontsize = 7, framealpha = 0., markerscale = 1)

    # Fit curves, divided by num_factor to match the plotted data:
    # twobodyloss (solid blue) and onebodyloss (dashed red).
    ax[1].plot(t, twobodyloss(t, popt_two[0], popt_two[1])/num_factor, 'b')
    ax[1].plot(t, onebodyloss(t, popt_one[0], popt_one[1])/num_factor, 'r--')
    
    # powerlimits (0, 0): let the y tick formatter use scientific notation at any magnitude.
    ax[1].yaxis.major.formatter.set_powerlimits((0,0))
    ax[1].set_xlim([0, 1.1])
    ax[1].set_ylim([-0.1, 2.7])
    ax[1].set_yticks([0., 1., 2.])

    s = 7
    # Text placed in figure coordinates near the bottom centre; presumably the
    # shared x label for the broken-axis temperature row (ax3/ax4 get no
    # xlabel of their own in this section) — confirm.
    fig.text(.57, .0375, 'Hold time (s)', ha='center', rotation='horizontal', fontsize=s)
    ax[1].set_xlabel("Hold time (s)")
    # Data were divided by num_factor (1e4), hence the x10^4 in the label.
    ax[1].set_ylabel("Molecule num. (" + str(r"$\times 10^{4}$") + ")")
    
    # Tick layout: ax[1] has y ticks on both sides; the broken-axis pair only
    # on its outer edges (left for ax[2], right for ax[3]).  All three get
    # x ticks top and bottom, drawn inward.
    for i, side in [(1, 'both'), (2, 'left'), (3, 'right')]:
        ax[i].yaxis.set_ticks_position(side)
        ax[i].xaxis.set_ticks_position('both')
        ax[i].tick_params(which='both', direction='in')

    d = .015  # half-length of each diagonal break mark, in axes-fraction units
    # Common line style for the marks; clip_on=False lets them poke outside
    # the axes.  The transform is switched per axes below.
    kwargs = dict(color='k', clip_on=False)

    # Right edge of ax3 (the left half): bottom corner, then top corner.
    kwargs['transform'] = ax3.transAxes
    for y0 in (0, 1):
        ax3.plot((1-d, 1+d), (y0-d, y0+d), **kwargs)

    # Left edge of ax4 (the right half): top corner, then bottom corner.
    kwargs['transform'] = ax4.transAxes
    for y0 in (1, 0):
        ax4.plot((-d, +d), (y0-d, y0+d), **kwargs)
    
    # Panel a: schematic image with all axis decoration removed.
    cartoon = plt.imread(join(working_dir, "cartoon_thick.jpg"))
    ax1.imshow(cartoon)
    ax1.set_xticklabels([])
    ax1.set_yticklabels([])
    ax1.set_axis_off()

    # Broken x-axis: hide the facing spines so ax3 and ax4 read as one panel,
    # and drop the duplicate y tick labels on the right half.
    ax3.spines['right'].set_visible(False)
    ax4.spines['left'].set_visible(False)
    ax4.set_yticklabels([])

    # Panel letters; the text objects are created at (0, 0) and positioned
    # by update_labels().  ax4 (right half of the broken axis) gets no letter.
    s = 8
    # Render once before the labels are positioned from axes bounding boxes.
    fig.canvas.draw()
    # Math delimiters must be balanced: an unterminated '$' is invalid LaTeX
    # under text.usetex.
    labels = [r'$\textbf{a.}$', r'$\textbf{b.}$', r'$\textbf{c.}$', '']
    axlabels = [fig.text(0, 0, label, fontsize=s, fontweight="bold", va="top", ha="left") for _, label in zip(ax, labels)]

    def update_labels(evt=None):
        """Move each panel label to the top-left corner of its axes' tight bbox.

        `evt` is accepted but ignored (presumably so this can also be
        connected as a matplotlib event callback — confirm).
        """
        # Maps display coordinates back to figure-fraction coordinates.
        trans = fig.transFigure.inverted()
        renderer = fig.canvas.get_renderer()
        for i, (a, label) in enumerate(zip(ax, axlabels)):
            bbox = a.get_tightbbox(renderer)
            x = bbox.x0
            y = bbox.y1
            if i == 1:
                # Nudge panel b's label 20 display units further left.
                x -= 20
            label.set_position(trans.transform_point([x, y]))
    update_labels()

    # Write the combined figure, cropped to its tight bounding box with a small pad.
    out_file = join(plotpath, "Thermalization_comb_break.pdf")
    plt.savefig(out_file, dpi=1000, pad_inches=0.01, bbox_inches='tight')